{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from collections import OrderedDict\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Requires the `tflite` flatbuffer parser package: pip install tflite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tflite import Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "b'main'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Path to the exported TFLite model (relative to the working directory).\n",
    "TFLITE_PATH = \"./M-LSD_512_tiny_fp32.tflite\"\n",
    "# Read the raw flatbuffer bytes; the `with` block closes the file handle.\n",
    "with open(TFLITE_PATH, \"rb\") as fh:\n",
    "    data = fh.read()\n",
    "tf_model = Model.GetRootAsModel(data, 0)\n",
    "# Only subgraph 0 is used below; display its name as a sanity check.\n",
    "subgraph = tf_model.Subgraphs(0)\n",
    "subgraph.Name()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_shape(tensor):\n",
    "    return [tensor.Shape(i) for i in range(tensor.ShapeLength())]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "WireFrameModel_with_alpha/Conv1_with_alpha/Conv2D [32, 3, 3, 4]\n",
      "WireFrameModel_with_alpha/bn_Conv1/FusedBatchNormV3 [32]\n",
      "2\n",
      "WireFrameModel_with_alpha/expanded_conv_depthwise_BN/FusedBatchNormV3;WireFrameModel_with_alpha/expanded_conv_depthwise/depthwise;WireFrameModel_with_alpha/decoder_fpn/upblock_1/conv_bn__act_4/conv/Conv2D [1, 3, 3, 32]\n",
      "WireFrameModel_with_alpha/expanded_conv_depthwise_BN/FusedBatchNormV3 [32]\n",
      "1\n",
      "WireFrameModel_with_alpha/expanded_conv_project/Conv2D [16, 1, 1, 32]\n",
      "WireFrameModel_with_alpha/expanded_conv_project_BN/FusedBatchNormV3;WireFrameModel_with_alpha/Decoder/conv2d/Conv2D;WireFrameModel_with_alpha/expanded_conv_project/Conv2D [16]\n",
      "1\n",
      "WireFrameModel_with_alpha/block_1_expand/Conv2D [96, 1, 1, 16]\n",
      "WireFrameModel_with_alpha/block_1_expand_BN/FusedBatchNormV3 [96]\n",
      "22\n",
      "2\n",
      "WireFrameModel_with_alpha/block_1_depthwise_BN/FusedBatchNormV3;WireFrameModel_with_alpha/block_1_depthwise/depthwise [1, 3, 3, 96]\n",
      "WireFrameModel_with_alpha/block_1_depthwise_BN/FusedBatchNormV3 [96]\n",
      "1\n",
      "WireFrameModel_with_alpha/block_1_project/Conv2D [24, 1, 1, 96]\n",
      "WireFrameModel_with_alpha/block_1_project_BN/FusedBatchNormV3;WireFrameModel_with_alpha/block_2_project/Conv2D;WireFrameModel_with_alpha/block_1_project/Conv2D [24]\n",
      "1\n",
      "WireFrameModel_with_alpha/block_2_expand/Conv2D [144, 1, 1, 24]\n",
      "WireFrameModel_with_alpha/block_2_expand_BN/FusedBatchNormV3 [144]\n",
      "2\n",
      "WireFrameModel_with_alpha/block_2_depthwise_BN/FusedBatchNormV3;WireFrameModel_with_alpha/block_2_depthwise/depthwise;WireFrameModel_with_alpha/block_3_depthwise/depthwise [1, 3, 3, 144]\n",
      "WireFrameModel_with_alpha/block_2_depthwise_BN/FusedBatchNormV3 [144]\n",
      "1\n",
      "WireFrameModel_with_alpha/block_2_project/Conv2D [24, 1, 1, 144]\n",
      "WireFrameModel_with_alpha/block_2_project_BN/FusedBatchNormV3;WireFrameModel_with_alpha/block_2_project/Conv2D [24]\n",
      "11\n",
      "1\n",
      "WireFrameModel_with_alpha/block_3_expand/Conv2D [144, 1, 1, 24]\n",
      "WireFrameModel_with_alpha/block_3_expand_BN/FusedBatchNormV3 [144]\n",
      "22\n",
      "2\n",
      "WireFrameModel_with_alpha/block_3_depthwise_BN/FusedBatchNormV3;WireFrameModel_with_alpha/block_3_depthwise/depthwise [1, 3, 3, 144]\n",
      "WireFrameModel_with_alpha/block_3_depthwise_BN/FusedBatchNormV3 [144]\n",
      "1\n",
      "WireFrameModel_with_alpha/block_3_project/Conv2D [32, 1, 1, 144]\n",
      "WireFrameModel_with_alpha/block_3_project_BN/FusedBatchNormV3;WireFrameModel_with_alpha/decoder_fpn/upblock_1/conv_bn__act_4/conv/Conv2D;WireFrameModel_with_alpha/block_3_project/Conv2D [32]\n",
      "1\n",
      "WireFrameModel_with_alpha/block_4_expand/Conv2D [192, 1, 1, 32]\n",
      "WireFrameModel_with_alpha/block_4_expand_BN/FusedBatchNormV3 [192]\n",
      "2\n",
      "WireFrameModel_with_alpha/block_4_depthwise_BN/FusedBatchNormV3;WireFrameModel_with_alpha/block_4_depthwise/depthwise;WireFrameModel_with_alpha/block_6_depthwise/depthwise [1, 3, 3, 192]\n",
      "WireFrameModel_with_alpha/block_4_depthwise_BN/FusedBatchNormV3 [192]\n",
      "1\n",
      "WireFrameModel_with_alpha/block_4_project/Conv2D [32, 1, 1, 192]\n",
      "WireFrameModel_with_alpha/block_4_project_BN/FusedBatchNormV3;WireFrameModel_with_alpha/decoder_fpn/upblock_1/conv_bn__act_4/conv/Conv2D;WireFrameModel_with_alpha/block_4_project/Conv2D [32]\n",
      "11\n",
      "1\n",
      "WireFrameModel_with_alpha/block_5_expand/Conv2D [192, 1, 1, 32]\n",
      "WireFrameModel_with_alpha/block_5_expand_BN/FusedBatchNormV3 [192]\n",
      "2\n",
      "WireFrameModel_with_alpha/block_5_depthwise_BN/FusedBatchNormV3;WireFrameModel_with_alpha/block_5_depthwise/depthwise;WireFrameModel_with_alpha/block_6_depthwise/depthwise [1, 3, 3, 192]\n",
      "WireFrameModel_with_alpha/block_5_depthwise_BN/FusedBatchNormV3 [192]\n",
      "1\n",
      "WireFrameModel_with_alpha/block_5_project/Conv2D [32, 1, 1, 192]\n",
      "WireFrameModel_with_alpha/block_5_project_BN/FusedBatchNormV3;WireFrameModel_with_alpha/decoder_fpn/upblock_1/conv_bn__act_4/conv/Conv2D;WireFrameModel_with_alpha/block_5_project/Conv2D [32]\n",
      "11\n",
      "1\n",
      "WireFrameModel_with_alpha/block_6_expand/Conv2D [192, 1, 1, 32]\n",
      "WireFrameModel_with_alpha/block_6_expand_BN/FusedBatchNormV3 [192]\n",
      "22\n",
      "2\n",
      "WireFrameModel_with_alpha/block_6_depthwise_BN/FusedBatchNormV3;WireFrameModel_with_alpha/block_6_depthwise/depthwise [1, 3, 3, 192]\n",
      "WireFrameModel_with_alpha/block_6_depthwise_BN/FusedBatchNormV3 [192]\n",
      "1\n",
      "WireFrameModel_with_alpha/block_6_project/Conv2D [64, 1, 1, 192]\n",
      "WireFrameModel_with_alpha/block_6_project_BN/FusedBatchNormV3;WireFrameModel_with_alpha/Decoder/conv_bn__act_9/conv/Conv2D;WireFrameModel_with_alpha/block_6_project/Conv2D [64]\n",
      "1\n",
      "WireFrameModel_with_alpha/block_7_expand/Conv2D [384, 1, 1, 64]\n",
      "WireFrameModel_with_alpha/block_7_expand_BN/FusedBatchNormV3 [384]\n",
      "2\n",
      "WireFrameModel_with_alpha/block_7_depthwise_BN/FusedBatchNormV3;WireFrameModel_with_alpha/block_7_depthwise/depthwise;WireFrameModel_with_alpha/block_9_depthwise/depthwise [1, 3, 3, 384]\n",
      "WireFrameModel_with_alpha/block_7_depthwise_BN/FusedBatchNormV3 [384]\n",
      "1\n",
      "WireFrameModel_with_alpha/block_7_project/Conv2D [64, 1, 1, 384]\n",
      "WireFrameModel_with_alpha/block_7_project_BN/FusedBatchNormV3;WireFrameModel_with_alpha/Decoder/conv_bn__act_9/conv/Conv2D;WireFrameModel_with_alpha/block_7_project/Conv2D [64]\n",
      "11\n",
      "1\n",
      "WireFrameModel_with_alpha/block_8_expand/Conv2D [384, 1, 1, 64]\n",
      "WireFrameModel_with_alpha/block_8_expand_BN/FusedBatchNormV3 [384]\n",
      "2\n",
      "WireFrameModel_with_alpha/block_8_depthwise_BN/FusedBatchNormV3;WireFrameModel_with_alpha/block_8_depthwise/depthwise;WireFrameModel_with_alpha/block_9_depthwise/depthwise [1, 3, 3, 384]\n",
      "WireFrameModel_with_alpha/block_8_depthwise_BN/FusedBatchNormV3 [384]\n",
      "1\n",
      "WireFrameModel_with_alpha/block_8_project/Conv2D [64, 1, 1, 384]\n",
      "WireFrameModel_with_alpha/block_8_project_BN/FusedBatchNormV3;WireFrameModel_with_alpha/Decoder/conv_bn__act_9/conv/Conv2D;WireFrameModel_with_alpha/block_8_project/Conv2D [64]\n",
      "11\n",
      "1\n",
      "WireFrameModel_with_alpha/block_9_expand/Conv2D [384, 1, 1, 64]\n",
      "WireFrameModel_with_alpha/block_9_expand_BN/FusedBatchNormV3 [384]\n",
      "2\n",
      "WireFrameModel_with_alpha/block_9_depthwise_BN/FusedBatchNormV3;WireFrameModel_with_alpha/block_9_depthwise/depthwise [1, 3, 3, 384]\n",
      "WireFrameModel_with_alpha/block_9_depthwise_BN/FusedBatchNormV3 [384]\n",
      "1\n",
      "WireFrameModel_with_alpha/block_9_project/Conv2D [64, 1, 1, 384]\n",
      "WireFrameModel_with_alpha/block_9_project_BN/FusedBatchNormV3;WireFrameModel_with_alpha/Decoder/conv_bn__act_9/conv/Conv2D;WireFrameModel_with_alpha/block_9_project/Conv2D [64]\n",
      "11\n",
      "15\n",
      "1\n",
      "WireFrameModel_with_alpha/decoder_fpn/upblock/conv_bn__act_1/conv/Conv2D [64, 1, 1, 64]\n",
      "WireFrameModel_with_alpha/decoder_fpn/upblock/conv_bn__act_1/bn/FusedBatchNormV3 [64]\n",
      "1\n",
      "WireFrameModel_with_alpha/decoder_fpn/upblock/conv_bn__act/conv/Conv2D [64, 1, 1, 32]\n",
      "WireFrameModel_with_alpha/decoder_fpn/upblock/conv_bn__act/bn/FusedBatchNormV3 [64]\n",
      "10\n",
      "1\n",
      "WireFrameModel_with_alpha/decoder_fpn/upblock/conv_bn__act_2/conv/Conv2D [128, 3, 3, 128]\n",
      "WireFrameModel_with_alpha/decoder_fpn/upblock/conv_bn__act_2/bn/FusedBatchNormV3 [128]\n",
      "11\n",
      "1\n",
      "WireFrameModel_with_alpha/decoder_fpn/upblock/conv_bn__act_3/conv/Conv2D [64, 3, 3, 128]\n",
      "WireFrameModel_with_alpha/decoder_fpn/upblock/conv_bn__act_3/bn/FusedBatchNormV3 [64]\n",
      "15\n",
      "1\n",
      "WireFrameModel_with_alpha/decoder_fpn/upblock_1/conv_bn__act_5/conv/Conv2D [32, 1, 1, 64]\n",
      "WireFrameModel_with_alpha/decoder_fpn/upblock_1/conv_bn__act_5/bn/FusedBatchNormV3 [32]\n",
      "1\n",
      "WireFrameModel_with_alpha/decoder_fpn/upblock_1/conv_bn__act_4/conv/Conv2D [32, 1, 1, 24]\n",
      "WireFrameModel_with_alpha/decoder_fpn/upblock_1/conv_bn__act_4/bn/FusedBatchNormV3 [32]\n",
      "10\n",
      "1\n",
      "WireFrameModel_with_alpha/decoder_fpn/upblock_1/conv_bn__act_6/conv/Conv2D [64, 3, 3, 64]\n",
      "WireFrameModel_with_alpha/decoder_fpn/upblock_1/conv_bn__act_6/bn/FusedBatchNormV3 [64]\n",
      "11\n",
      "1\n",
      "WireFrameModel_with_alpha/decoder_fpn/upblock_1/conv_bn__act_7/conv/Conv2D [64, 3, 3, 64]\n",
      "WireFrameModel_with_alpha/decoder_fpn/upblock_1/conv_bn__act_7/bn/FusedBatchNormV3 [64]\n",
      "1\n",
      "WireFrameModel_with_alpha/Decoder/conv_bn__act_8/conv/Conv2D [64, 3, 3, 64]\n",
      "WireFrameModel_with_alpha/Decoder/conv_bn__act_8/bn/FusedBatchNormV3 [64]\n",
      "1\n",
      "WireFrameModel_with_alpha/Decoder/conv_bn__act_9/conv/Conv2D [64, 3, 3, 64]\n",
      "WireFrameModel_with_alpha/Decoder/conv_bn__act_9/bn/FusedBatchNormV3 [64]\n",
      "1\n",
      "WireFrameModel_with_alpha/Decoder/conv2d/Conv2D [16, 1, 1, 64]\n",
      "WireFrameModel_with_alpha/Decoder/conv2d/BiasAdd;WireFrameModel_with_alpha/Decoder/conv2d/Conv2D;WireFrameModel_with_alpha/Decoder/conv2d/BiasAdd/ReadVariableOp/resource [16]\n",
      "15\n",
      "32\n",
      "0\n",
      "5\n",
      "0\n",
      "0\n",
      "21\n",
      "0\n",
      "34\n",
      "0\n",
      "0\n",
      "0\n",
      "0\n",
      "10\n",
      "32\n"
     ]
    }
   ],
   "source": [
    "# TFLite schema enum values used below.\n",
    "CONV2D_OPTIONS = 1            # BuiltinOptions.Conv2DOptions\n",
    "DEPTHWISE_CONV2D_OPTIONS = 2  # BuiltinOptions.DepthwiseConv2DOptions\n",
    "FLOAT_TYPES = (0, 1)          # TensorType.FLOAT32, TensorType.FLOAT16\n",
    "# NOTE(review): operators 0-2 are skipped; confirm none of them is a conv whose weights are needed.\n",
    "FIRST_OP_INDEX = 3\n",
    "\n",
    "tensor_list = []\n",
    "for i in range(FIRST_OP_INDEX, subgraph.OperatorsLength()):\n",
    "    op = subgraph.Operators(i)\n",
    "    op_type = op.BuiltinOptionsType()\n",
    "    print(op_type)\n",
    "    if op_type not in (CONV2D_OPTIONS, DEPTHWISE_CONV2D_OPTIONS):\n",
    "        continue\n",
    "    # Classify from the operator itself rather than from the tensor name,\n",
    "    # which is a ';'-joined list of fused TF op names and can mention other layers.\n",
    "    layer_type = 'depthwise_conv' if op_type == DEPTHWISE_CONV2D_OPTIONS else 'conv'\n",
    "    one_layer_p = []\n",
    "    # Input 0 is the op's data input; inputs 1.. hold the filter and bias.\n",
    "    for j in range(1, op.InputsLength()):\n",
    "        tensor = subgraph.Tensors(op.Inputs(j))\n",
    "        # Buffer 0 is the empty sentinel buffer, so constant tensors have Buffer() > 0.\n",
    "        if tensor.Buffer() <= 0 or tensor.Type() not in FLOAT_TYPES:\n",
    "            continue\n",
    "        s = get_shape(tensor)\n",
    "        name = tensor.Name().decode('utf-8')\n",
    "        print(name, s)\n",
    "        # Rank-1 tensors are biases; higher-rank tensors are kernels.\n",
    "        suffix = '_bias' if len(s) == 1 else '_weight'\n",
    "        one_layer_p.append({\n",
    "            'name': name,\n",
    "            'tensor': tensor,\n",
    "            'shape': s,\n",
    "            'type_maybe': layer_type + suffix,\n",
    "        })\n",
    "    if one_layer_p:\n",
    "        tensor_list.append(one_layer_p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# The `models` package is imported from the parent directory; add it to the\n",
    "# import path only once so re-running this cell does not keep growing sys.path.\n",
    "PARENT_DIR = '../'\n",
    "if PARENT_DIR not in sys.path:\n",
    "    sys.path.append(PARENT_DIR)\n",
    "from models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny\n",
    "\n",
    "# eval() switches BatchNorm to running statistics and returns the module itself.\n",
    "model = MobileV2_MLSD_Tiny().eval()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# The TFLite graph has BatchNorm fused into the conv kernels/biases. BN computes\n",
    "#   y = (x - running_mean) / sqrt(running_var + eps) * weight + bias\n",
    "# so weight=1, bias=0 and running_var = 1 - eps reduce it to y = x - running_mean\n",
    "# (up to float rounding). Conv biases are zeroed as well.\n",
    "with torch.no_grad():\n",
    "    for m in model.modules():\n",
    "        if isinstance(m, nn.Conv2d):\n",
    "            if m.bias is not None:\n",
    "                nn.init.zeros_(m.bias)\n",
    "        elif isinstance(m, nn.BatchNorm2d):\n",
    "            nn.init.ones_(m.weight)\n",
    "            nn.init.zeros_(m.bias)\n",
    "            # Store 1 - eps itself (not its square root) so sqrt(var + eps) == 1.\n",
    "            m.running_var.fill_(1.0 - m.eps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "backbone.features.0.0.weight torch.Size([32, 4, 3, 3])\n",
      "backbone.features.0.1.weight torch.Size([32])\n",
      "backbone.features.0.1.bias torch.Size([32])\n",
      "backbone.features.0.1.running_mean torch.Size([32])\n",
      "backbone.features.0.1.running_var torch.Size([32])\n",
      "backbone.features.0.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.1.conv.0.0.weight torch.Size([32, 1, 3, 3])\n",
      "backbone.features.1.conv.0.1.weight torch.Size([32])\n",
      "backbone.features.1.conv.0.1.bias torch.Size([32])\n",
      "backbone.features.1.conv.0.1.running_mean torch.Size([32])\n",
      "backbone.features.1.conv.0.1.running_var torch.Size([32])\n",
      "backbone.features.1.conv.0.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.1.conv.1.weight torch.Size([16, 32, 1, 1])\n",
      "backbone.features.1.conv.2.weight torch.Size([16])\n",
      "backbone.features.1.conv.2.bias torch.Size([16])\n",
      "backbone.features.1.conv.2.running_mean torch.Size([16])\n",
      "backbone.features.1.conv.2.running_var torch.Size([16])\n",
      "backbone.features.1.conv.2.num_batches_tracked torch.Size([])\n",
      "backbone.features.2.conv.0.0.weight torch.Size([96, 16, 1, 1])\n",
      "backbone.features.2.conv.0.1.weight torch.Size([96])\n",
      "backbone.features.2.conv.0.1.bias torch.Size([96])\n",
      "backbone.features.2.conv.0.1.running_mean torch.Size([96])\n",
      "backbone.features.2.conv.0.1.running_var torch.Size([96])\n",
      "backbone.features.2.conv.0.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.2.conv.1.0.weight torch.Size([96, 1, 3, 3])\n",
      "backbone.features.2.conv.1.1.weight torch.Size([96])\n",
      "backbone.features.2.conv.1.1.bias torch.Size([96])\n",
      "backbone.features.2.conv.1.1.running_mean torch.Size([96])\n",
      "backbone.features.2.conv.1.1.running_var torch.Size([96])\n",
      "backbone.features.2.conv.1.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.2.conv.2.weight torch.Size([24, 96, 1, 1])\n",
      "backbone.features.2.conv.3.weight torch.Size([24])\n",
      "backbone.features.2.conv.3.bias torch.Size([24])\n",
      "backbone.features.2.conv.3.running_mean torch.Size([24])\n",
      "backbone.features.2.conv.3.running_var torch.Size([24])\n",
      "backbone.features.2.conv.3.num_batches_tracked torch.Size([])\n",
      "backbone.features.3.conv.0.0.weight torch.Size([144, 24, 1, 1])\n",
      "backbone.features.3.conv.0.1.weight torch.Size([144])\n",
      "backbone.features.3.conv.0.1.bias torch.Size([144])\n",
      "backbone.features.3.conv.0.1.running_mean torch.Size([144])\n",
      "backbone.features.3.conv.0.1.running_var torch.Size([144])\n",
      "backbone.features.3.conv.0.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.3.conv.1.0.weight torch.Size([144, 1, 3, 3])\n",
      "backbone.features.3.conv.1.1.weight torch.Size([144])\n",
      "backbone.features.3.conv.1.1.bias torch.Size([144])\n",
      "backbone.features.3.conv.1.1.running_mean torch.Size([144])\n",
      "backbone.features.3.conv.1.1.running_var torch.Size([144])\n",
      "backbone.features.3.conv.1.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.3.conv.2.weight torch.Size([24, 144, 1, 1])\n",
      "backbone.features.3.conv.3.weight torch.Size([24])\n",
      "backbone.features.3.conv.3.bias torch.Size([24])\n",
      "backbone.features.3.conv.3.running_mean torch.Size([24])\n",
      "backbone.features.3.conv.3.running_var torch.Size([24])\n",
      "backbone.features.3.conv.3.num_batches_tracked torch.Size([])\n",
      "backbone.features.4.conv.0.0.weight torch.Size([144, 24, 1, 1])\n",
      "backbone.features.4.conv.0.1.weight torch.Size([144])\n",
      "backbone.features.4.conv.0.1.bias torch.Size([144])\n",
      "backbone.features.4.conv.0.1.running_mean torch.Size([144])\n",
      "backbone.features.4.conv.0.1.running_var torch.Size([144])\n",
      "backbone.features.4.conv.0.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.4.conv.1.0.weight torch.Size([144, 1, 3, 3])\n",
      "backbone.features.4.conv.1.1.weight torch.Size([144])\n",
      "backbone.features.4.conv.1.1.bias torch.Size([144])\n",
      "backbone.features.4.conv.1.1.running_mean torch.Size([144])\n",
      "backbone.features.4.conv.1.1.running_var torch.Size([144])\n",
      "backbone.features.4.conv.1.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.4.conv.2.weight torch.Size([32, 144, 1, 1])\n",
      "backbone.features.4.conv.3.weight torch.Size([32])\n",
      "backbone.features.4.conv.3.bias torch.Size([32])\n",
      "backbone.features.4.conv.3.running_mean torch.Size([32])\n",
      "backbone.features.4.conv.3.running_var torch.Size([32])\n",
      "backbone.features.4.conv.3.num_batches_tracked torch.Size([])\n",
      "backbone.features.5.conv.0.0.weight torch.Size([192, 32, 1, 1])\n",
      "backbone.features.5.conv.0.1.weight torch.Size([192])\n",
      "backbone.features.5.conv.0.1.bias torch.Size([192])\n",
      "backbone.features.5.conv.0.1.running_mean torch.Size([192])\n",
      "backbone.features.5.conv.0.1.running_var torch.Size([192])\n",
      "backbone.features.5.conv.0.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.5.conv.1.0.weight torch.Size([192, 1, 3, 3])\n",
      "backbone.features.5.conv.1.1.weight torch.Size([192])\n",
      "backbone.features.5.conv.1.1.bias torch.Size([192])\n",
      "backbone.features.5.conv.1.1.running_mean torch.Size([192])\n",
      "backbone.features.5.conv.1.1.running_var torch.Size([192])\n",
      "backbone.features.5.conv.1.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.5.conv.2.weight torch.Size([32, 192, 1, 1])\n",
      "backbone.features.5.conv.3.weight torch.Size([32])\n",
      "backbone.features.5.conv.3.bias torch.Size([32])\n",
      "backbone.features.5.conv.3.running_mean torch.Size([32])\n",
      "backbone.features.5.conv.3.running_var torch.Size([32])\n",
      "backbone.features.5.conv.3.num_batches_tracked torch.Size([])\n",
      "backbone.features.6.conv.0.0.weight torch.Size([192, 32, 1, 1])\n",
      "backbone.features.6.conv.0.1.weight torch.Size([192])\n",
      "backbone.features.6.conv.0.1.bias torch.Size([192])\n",
      "backbone.features.6.conv.0.1.running_mean torch.Size([192])\n",
      "backbone.features.6.conv.0.1.running_var torch.Size([192])\n",
      "backbone.features.6.conv.0.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.6.conv.1.0.weight torch.Size([192, 1, 3, 3])\n",
      "backbone.features.6.conv.1.1.weight torch.Size([192])\n",
      "backbone.features.6.conv.1.1.bias torch.Size([192])\n",
      "backbone.features.6.conv.1.1.running_mean torch.Size([192])\n",
      "backbone.features.6.conv.1.1.running_var torch.Size([192])\n",
      "backbone.features.6.conv.1.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.6.conv.2.weight torch.Size([32, 192, 1, 1])\n",
      "backbone.features.6.conv.3.weight torch.Size([32])\n",
      "backbone.features.6.conv.3.bias torch.Size([32])\n",
      "backbone.features.6.conv.3.running_mean torch.Size([32])\n",
      "backbone.features.6.conv.3.running_var torch.Size([32])\n",
      "backbone.features.6.conv.3.num_batches_tracked torch.Size([])\n",
      "backbone.features.7.conv.0.0.weight torch.Size([192, 32, 1, 1])\n",
      "backbone.features.7.conv.0.1.weight torch.Size([192])\n",
      "backbone.features.7.conv.0.1.bias torch.Size([192])\n",
      "backbone.features.7.conv.0.1.running_mean torch.Size([192])\n",
      "backbone.features.7.conv.0.1.running_var torch.Size([192])\n",
      "backbone.features.7.conv.0.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.7.conv.1.0.weight torch.Size([192, 1, 3, 3])\n",
      "backbone.features.7.conv.1.1.weight torch.Size([192])\n",
      "backbone.features.7.conv.1.1.bias torch.Size([192])\n",
      "backbone.features.7.conv.1.1.running_mean torch.Size([192])\n",
      "backbone.features.7.conv.1.1.running_var torch.Size([192])\n",
      "backbone.features.7.conv.1.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.7.conv.2.weight torch.Size([64, 192, 1, 1])\n",
      "backbone.features.7.conv.3.weight torch.Size([64])\n",
      "backbone.features.7.conv.3.bias torch.Size([64])\n",
      "backbone.features.7.conv.3.running_mean torch.Size([64])\n",
      "backbone.features.7.conv.3.running_var torch.Size([64])\n",
      "backbone.features.7.conv.3.num_batches_tracked torch.Size([])\n",
      "backbone.features.8.conv.0.0.weight torch.Size([384, 64, 1, 1])\n",
      "backbone.features.8.conv.0.1.weight torch.Size([384])\n",
      "backbone.features.8.conv.0.1.bias torch.Size([384])\n",
      "backbone.features.8.conv.0.1.running_mean torch.Size([384])\n",
      "backbone.features.8.conv.0.1.running_var torch.Size([384])\n",
      "backbone.features.8.conv.0.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.8.conv.1.0.weight torch.Size([384, 1, 3, 3])\n",
      "backbone.features.8.conv.1.1.weight torch.Size([384])\n",
      "backbone.features.8.conv.1.1.bias torch.Size([384])\n",
      "backbone.features.8.conv.1.1.running_mean torch.Size([384])\n",
      "backbone.features.8.conv.1.1.running_var torch.Size([384])\n",
      "backbone.features.8.conv.1.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.8.conv.2.weight torch.Size([64, 384, 1, 1])\n",
      "backbone.features.8.conv.3.weight torch.Size([64])\n",
      "backbone.features.8.conv.3.bias torch.Size([64])\n",
      "backbone.features.8.conv.3.running_mean torch.Size([64])\n",
      "backbone.features.8.conv.3.running_var torch.Size([64])\n",
      "backbone.features.8.conv.3.num_batches_tracked torch.Size([])\n",
      "backbone.features.9.conv.0.0.weight torch.Size([384, 64, 1, 1])\n",
      "backbone.features.9.conv.0.1.weight torch.Size([384])\n",
      "backbone.features.9.conv.0.1.bias torch.Size([384])\n",
      "backbone.features.9.conv.0.1.running_mean torch.Size([384])\n",
      "backbone.features.9.conv.0.1.running_var torch.Size([384])\n",
      "backbone.features.9.conv.0.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.9.conv.1.0.weight torch.Size([384, 1, 3, 3])\n",
      "backbone.features.9.conv.1.1.weight torch.Size([384])\n",
      "backbone.features.9.conv.1.1.bias torch.Size([384])\n",
      "backbone.features.9.conv.1.1.running_mean torch.Size([384])\n",
      "backbone.features.9.conv.1.1.running_var torch.Size([384])\n",
      "backbone.features.9.conv.1.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.9.conv.2.weight torch.Size([64, 384, 1, 1])\n",
      "backbone.features.9.conv.3.weight torch.Size([64])\n",
      "backbone.features.9.conv.3.bias torch.Size([64])\n",
      "backbone.features.9.conv.3.running_mean torch.Size([64])\n",
      "backbone.features.9.conv.3.running_var torch.Size([64])\n",
      "backbone.features.9.conv.3.num_batches_tracked torch.Size([])\n",
      "backbone.features.10.conv.0.0.weight torch.Size([384, 64, 1, 1])\n",
      "backbone.features.10.conv.0.1.weight torch.Size([384])\n",
      "backbone.features.10.conv.0.1.bias torch.Size([384])\n",
      "backbone.features.10.conv.0.1.running_mean torch.Size([384])\n",
      "backbone.features.10.conv.0.1.running_var torch.Size([384])\n",
      "backbone.features.10.conv.0.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.10.conv.1.0.weight torch.Size([384, 1, 3, 3])\n",
      "backbone.features.10.conv.1.1.weight torch.Size([384])\n",
      "backbone.features.10.conv.1.1.bias torch.Size([384])\n",
      "backbone.features.10.conv.1.1.running_mean torch.Size([384])\n",
      "backbone.features.10.conv.1.1.running_var torch.Size([384])\n",
      "backbone.features.10.conv.1.1.num_batches_tracked torch.Size([])\n",
      "backbone.features.10.conv.2.weight torch.Size([64, 384, 1, 1])\n",
      "backbone.features.10.conv.3.weight torch.Size([64])\n",
      "backbone.features.10.conv.3.bias torch.Size([64])\n",
      "backbone.features.10.conv.3.running_mean torch.Size([64])\n",
      "backbone.features.10.conv.3.running_var torch.Size([64])\n",
      "backbone.features.10.conv.3.num_batches_tracked torch.Size([])\n",
      "block12.conv1.0.weight torch.Size([64, 64, 1, 1])\n",
      "block12.conv1.0.bias torch.Size([64])\n",
      "block12.conv1.1.weight torch.Size([64])\n",
      "block12.conv1.1.bias torch.Size([64])\n",
      "block12.conv1.1.running_mean torch.Size([64])\n",
      "block12.conv1.1.running_var torch.Size([64])\n",
      "block12.conv1.1.num_batches_tracked torch.Size([])\n",
      "block12.conv2.0.weight torch.Size([64, 32, 1, 1])\n",
      "block12.conv2.0.bias torch.Size([64])\n",
      "block12.conv2.1.weight torch.Size([64])\n",
      "block12.conv2.1.bias torch.Size([64])\n",
      "block12.conv2.1.running_mean torch.Size([64])\n",
      "block12.conv2.1.running_var torch.Size([64])\n",
      "block12.conv2.1.num_batches_tracked torch.Size([])\n",
      "block13.conv1.0.weight torch.Size([128, 128, 3, 3])\n",
      "block13.conv1.0.bias torch.Size([128])\n",
      "block13.conv1.1.weight torch.Size([128])\n",
      "block13.conv1.1.bias torch.Size([128])\n",
      "block13.conv1.1.running_mean torch.Size([128])\n",
      "block13.conv1.1.running_var torch.Size([128])\n",
      "block13.conv1.1.num_batches_tracked torch.Size([])\n",
      "block13.conv2.0.weight torch.Size([64, 128, 3, 3])\n",
      "block13.conv2.0.bias torch.Size([64])\n",
      "block13.conv2.1.weight torch.Size([64])\n",
      "block13.conv2.1.bias torch.Size([64])\n",
      "block13.conv2.1.running_mean torch.Size([64])\n",
      "block13.conv2.1.running_var torch.Size([64])\n",
      "block13.conv2.1.num_batches_tracked torch.Size([])\n",
      "block14.conv1.0.weight torch.Size([32, 64, 1, 1])\n",
      "block14.conv1.0.bias torch.Size([32])\n",
      "block14.conv1.1.weight torch.Size([32])\n",
      "block14.conv1.1.bias torch.Size([32])\n",
      "block14.conv1.1.running_mean torch.Size([32])\n",
      "block14.conv1.1.running_var torch.Size([32])\n",
      "block14.conv1.1.num_batches_tracked torch.Size([])\n",
      "block14.conv2.0.weight torch.Size([32, 24, 1, 1])\n",
      "block14.conv2.0.bias torch.Size([32])\n",
      "block14.conv2.1.weight torch.Size([32])\n",
      "block14.conv2.1.bias torch.Size([32])\n",
      "block14.conv2.1.running_mean torch.Size([32])\n",
      "block14.conv2.1.running_var torch.Size([32])\n",
      "block14.conv2.1.num_batches_tracked torch.Size([])\n",
      "block15.conv1.0.weight torch.Size([64, 64, 3, 3])\n",
      "block15.conv1.0.bias torch.Size([64])\n",
      "block15.conv1.1.weight torch.Size([64])\n",
      "block15.conv1.1.bias torch.Size([64])\n",
      "block15.conv1.1.running_mean torch.Size([64])\n",
      "block15.conv1.1.running_var torch.Size([64])\n",
      "block15.conv1.1.num_batches_tracked torch.Size([])\n",
      "block15.conv2.0.weight torch.Size([64, 64, 3, 3])\n",
      "block15.conv2.0.bias torch.Size([64])\n",
      "block15.conv2.1.weight torch.Size([64])\n",
      "block15.conv2.1.bias torch.Size([64])\n",
      "block15.conv2.1.running_mean torch.Size([64])\n",
      "block15.conv2.1.running_var torch.Size([64])\n",
      "block15.conv2.1.num_batches_tracked torch.Size([])\n",
      "block16.conv1.0.weight torch.Size([64, 64, 3, 3])\n",
      "block16.conv1.0.bias torch.Size([64])\n",
      "block16.conv1.1.weight torch.Size([64])\n",
      "block16.conv1.1.bias torch.Size([64])\n",
      "block16.conv1.1.running_mean torch.Size([64])\n",
      "block16.conv1.1.running_var torch.Size([64])\n",
      "block16.conv1.1.num_batches_tracked torch.Size([])\n",
      "block16.conv2.0.weight torch.Size([64, 64, 3, 3])\n",
      "block16.conv2.0.bias torch.Size([64])\n",
      "block16.conv2.1.weight torch.Size([64])\n",
      "block16.conv2.1.bias torch.Size([64])\n",
      "block16.conv2.1.running_mean torch.Size([64])\n",
      "block16.conv2.1.running_var torch.Size([64])\n",
      "block16.conv2.1.num_batches_tracked torch.Size([])\n",
      "block16.conv3.weight torch.Size([16, 64, 1, 1])\n",
      "block16.conv3.bias torch.Size([16])\n"
     ]
    }
   ],
   "source": [
    "# Bias of the final conv (it has its own bias instead of a BN layer).\n",
    "# For the large model, change the block index to block23 (e.g. 'block23.conv3.bias' -- confirm).\n",
    "FINAL_CONV_BIAS_KEY = 'block16.conv3.bias'\n",
    "\n",
    "state_dict = model.state_dict()\n",
    "state_dict_selected = {}\n",
    "for k, v in state_dict.items():\n",
    "    print(k, v.shape)\n",
    "    # Load targets: every 4-D conv kernel, every BN running_mean, and the final conv bias.\n",
    "    if len(v.shape) == 4 or 'running_mean' in k or FINAL_CONV_BIAS_KEY in k:\n",
    "        state_dict_selected[k] = v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "backbone.features.0.0.weight torch.Size([32, 4, 3, 3]) torch.Size([32, 4, 3, 3])\n",
      "backbone.features.0.1.running_mean torch.Size([32]) torch.Size([32])\n",
      "backbone.features.1.conv.0.0.weight torch.Size([32, 1, 3, 3]) torch.Size([32, 1, 3, 3])\n",
      "backbone.features.1.conv.0.1.running_mean torch.Size([32]) torch.Size([32])\n",
      "backbone.features.1.conv.1.weight torch.Size([16, 32, 1, 1]) torch.Size([16, 32, 1, 1])\n",
      "backbone.features.1.conv.2.running_mean torch.Size([16]) torch.Size([16])\n",
      "backbone.features.2.conv.0.0.weight torch.Size([96, 16, 1, 1]) torch.Size([96, 16, 1, 1])\n",
      "backbone.features.2.conv.0.1.running_mean torch.Size([96]) torch.Size([96])\n",
      "backbone.features.2.conv.1.0.weight torch.Size([96, 1, 3, 3]) torch.Size([96, 1, 3, 3])\n",
      "backbone.features.2.conv.1.1.running_mean torch.Size([96]) torch.Size([96])\n",
      "backbone.features.2.conv.2.weight torch.Size([24, 96, 1, 1]) torch.Size([24, 96, 1, 1])\n",
      "backbone.features.2.conv.3.running_mean torch.Size([24]) torch.Size([24])\n",
      "backbone.features.3.conv.0.0.weight torch.Size([144, 24, 1, 1]) torch.Size([144, 24, 1, 1])\n",
      "backbone.features.3.conv.0.1.running_mean torch.Size([144]) torch.Size([144])\n",
      "backbone.features.3.conv.1.0.weight torch.Size([144, 1, 3, 3]) torch.Size([144, 1, 3, 3])\n",
      "backbone.features.3.conv.1.1.running_mean torch.Size([144]) torch.Size([144])\n",
      "backbone.features.3.conv.2.weight torch.Size([24, 144, 1, 1]) torch.Size([24, 144, 1, 1])\n",
      "backbone.features.3.conv.3.running_mean torch.Size([24]) torch.Size([24])\n",
      "backbone.features.4.conv.0.0.weight torch.Size([144, 24, 1, 1]) torch.Size([144, 24, 1, 1])\n",
      "backbone.features.4.conv.0.1.running_mean torch.Size([144]) torch.Size([144])\n",
      "backbone.features.4.conv.1.0.weight torch.Size([144, 1, 3, 3]) torch.Size([144, 1, 3, 3])\n",
      "backbone.features.4.conv.1.1.running_mean torch.Size([144]) torch.Size([144])\n",
      "backbone.features.4.conv.2.weight torch.Size([32, 144, 1, 1]) torch.Size([32, 144, 1, 1])\n",
      "backbone.features.4.conv.3.running_mean torch.Size([32]) torch.Size([32])\n",
      "backbone.features.5.conv.0.0.weight torch.Size([192, 32, 1, 1]) torch.Size([192, 32, 1, 1])\n",
      "backbone.features.5.conv.0.1.running_mean torch.Size([192]) torch.Size([192])\n",
      "backbone.features.5.conv.1.0.weight torch.Size([192, 1, 3, 3]) torch.Size([192, 1, 3, 3])\n",
      "backbone.features.5.conv.1.1.running_mean torch.Size([192]) torch.Size([192])\n",
      "backbone.features.5.conv.2.weight torch.Size([32, 192, 1, 1]) torch.Size([32, 192, 1, 1])\n",
      "backbone.features.5.conv.3.running_mean torch.Size([32]) torch.Size([32])\n",
      "backbone.features.6.conv.0.0.weight torch.Size([192, 32, 1, 1]) torch.Size([192, 32, 1, 1])\n",
      "backbone.features.6.conv.0.1.running_mean torch.Size([192]) torch.Size([192])\n",
      "backbone.features.6.conv.1.0.weight torch.Size([192, 1, 3, 3]) torch.Size([192, 1, 3, 3])\n",
      "backbone.features.6.conv.1.1.running_mean torch.Size([192]) torch.Size([192])\n",
      "backbone.features.6.conv.2.weight torch.Size([32, 192, 1, 1]) torch.Size([32, 192, 1, 1])\n",
      "backbone.features.6.conv.3.running_mean torch.Size([32]) torch.Size([32])\n",
      "backbone.features.7.conv.0.0.weight torch.Size([192, 32, 1, 1]) torch.Size([192, 32, 1, 1])\n",
      "backbone.features.7.conv.0.1.running_mean torch.Size([192]) torch.Size([192])\n",
      "backbone.features.7.conv.1.0.weight torch.Size([192, 1, 3, 3]) torch.Size([192, 1, 3, 3])\n",
      "backbone.features.7.conv.1.1.running_mean torch.Size([192]) torch.Size([192])\n",
      "backbone.features.7.conv.2.weight torch.Size([64, 192, 1, 1]) torch.Size([64, 192, 1, 1])\n",
      "backbone.features.7.conv.3.running_mean torch.Size([64]) torch.Size([64])\n",
      "backbone.features.8.conv.0.0.weight torch.Size([384, 64, 1, 1]) torch.Size([384, 64, 1, 1])\n",
      "backbone.features.8.conv.0.1.running_mean torch.Size([384]) torch.Size([384])\n",
      "backbone.features.8.conv.1.0.weight torch.Size([384, 1, 3, 3]) torch.Size([384, 1, 3, 3])\n",
      "backbone.features.8.conv.1.1.running_mean torch.Size([384]) torch.Size([384])\n",
      "backbone.features.8.conv.2.weight torch.Size([64, 384, 1, 1]) torch.Size([64, 384, 1, 1])\n",
      "backbone.features.8.conv.3.running_mean torch.Size([64]) torch.Size([64])\n",
      "backbone.features.9.conv.0.0.weight torch.Size([384, 64, 1, 1]) torch.Size([384, 64, 1, 1])\n",
      "backbone.features.9.conv.0.1.running_mean torch.Size([384]) torch.Size([384])\n",
      "backbone.features.9.conv.1.0.weight torch.Size([384, 1, 3, 3]) torch.Size([384, 1, 3, 3])\n",
      "backbone.features.9.conv.1.1.running_mean torch.Size([384]) torch.Size([384])\n",
      "backbone.features.9.conv.2.weight torch.Size([64, 384, 1, 1]) torch.Size([64, 384, 1, 1])\n",
      "backbone.features.9.conv.3.running_mean torch.Size([64]) torch.Size([64])\n",
      "backbone.features.10.conv.0.0.weight torch.Size([384, 64, 1, 1]) torch.Size([384, 64, 1, 1])\n",
      "backbone.features.10.conv.0.1.running_mean torch.Size([384]) torch.Size([384])\n",
      "backbone.features.10.conv.1.0.weight torch.Size([384, 1, 3, 3]) torch.Size([384, 1, 3, 3])\n",
      "backbone.features.10.conv.1.1.running_mean torch.Size([384]) torch.Size([384])\n",
      "backbone.features.10.conv.2.weight torch.Size([64, 384, 1, 1]) torch.Size([64, 384, 1, 1])\n",
      "backbone.features.10.conv.3.running_mean torch.Size([64]) torch.Size([64])\n",
      "block12.conv1.0.weight torch.Size([64, 64, 1, 1]) torch.Size([64, 64, 1, 1])\n",
      "block12.conv1.1.running_mean torch.Size([64]) torch.Size([64])\n",
      "block12.conv2.0.weight torch.Size([64, 32, 1, 1]) torch.Size([64, 32, 1, 1])\n",
      "block12.conv2.1.running_mean torch.Size([64]) torch.Size([64])\n",
      "block13.conv1.0.weight torch.Size([128, 128, 3, 3]) torch.Size([128, 128, 3, 3])\n",
      "block13.conv1.1.running_mean torch.Size([128]) torch.Size([128])\n",
      "block13.conv2.0.weight torch.Size([64, 128, 3, 3]) torch.Size([64, 128, 3, 3])\n",
      "block13.conv2.1.running_mean torch.Size([64]) torch.Size([64])\n",
      "block14.conv1.0.weight torch.Size([32, 64, 1, 1]) torch.Size([32, 64, 1, 1])\n",
      "block14.conv1.1.running_mean torch.Size([32]) torch.Size([32])\n",
      "block14.conv2.0.weight torch.Size([32, 24, 1, 1]) torch.Size([32, 24, 1, 1])\n",
      "block14.conv2.1.running_mean torch.Size([32]) torch.Size([32])\n",
      "block15.conv1.0.weight torch.Size([64, 64, 3, 3]) torch.Size([64, 64, 3, 3])\n",
      "block15.conv1.1.running_mean torch.Size([64]) torch.Size([64])\n",
      "block15.conv2.0.weight torch.Size([64, 64, 3, 3]) torch.Size([64, 64, 3, 3])\n",
      "block15.conv2.1.running_mean torch.Size([64]) torch.Size([64])\n",
      "block16.conv1.0.weight torch.Size([64, 64, 3, 3]) torch.Size([64, 64, 3, 3])\n",
      "block16.conv1.1.running_mean torch.Size([64]) torch.Size([64])\n",
      "block16.conv2.0.weight torch.Size([64, 64, 3, 3]) torch.Size([64, 64, 3, 3])\n",
      "block16.conv2.1.running_mean torch.Size([64]) torch.Size([64])\n",
      "block16.conv3.weight torch.Size([16, 64, 1, 1]) torch.Size([16, 64, 1, 1])\n",
      "block16.conv3.bias torch.Size([16]) torch.Size([16])\n"
     ]
    }
   ],
   "source": [
    "# Flatten the per-group TFLite tensor records into one list; its order must match state_dict_selected.\n",
    "tf_tensor_list = [record for group in tensor_list for record in group]\n",
    "assert len(tf_tensor_list) >= len(state_dict_selected), (\n",
    "    'fewer TFLite tensors than selected PyTorch parameters')\n",
    "\n",
    "for i, (k, v) in enumerate(state_dict_selected.items()):\n",
    "    tf_tensor = tf_tensor_list[i]['tensor']\n",
    "    buffer = tf_tensor.Buffer()\n",
    "    shape = get_shape(tf_tensor)\n",
    "    tensor_type = tf_tensor.Type()  # TFLite TensorType: 0 = FLOAT32, 1 = FLOAT16\n",
    "    assert tensor_type in (0, 1), 'unsupported TFLite tensor type %d for %s' % (tensor_type, k)\n",
    "\n",
    "    # The buffer holds raw bytes; reinterpret them with this tensor's own dtype.\n",
    "    W = np.array(tf_model.Buffers(buffer).DataAsNumpy())\n",
    "    W = W.view(dtype=np.float32 if tensor_type == 0 else np.float16)\n",
    "    W = W.reshape(shape)\n",
    "    if W.ndim == 4:\n",
    "        # TFLite stores conv kernels as OHWI and depthwise kernels as 1HWC; PyTorch wants OIHW.\n",
    "        # NOTE(review): a regular conv with a single output channel would be treated as depthwise.\n",
    "        if W.shape[0] == 1:\n",
    "            W = W.transpose((3, 0, 1, 2))  # depthwise conv\n",
    "        else:\n",
    "            W = W.transpose((0, 3, 1, 2))  # conv\n",
    "    W_torch = torch.from_numpy(W)\n",
    "    if W.ndim == 1 and 'bias' not in k:\n",
    "        # 1-D entries other than the head bias are BN running means: store the fused BN bias\n",
    "        # negated (presumably BN weight/var stay at their defaults so the layer adds the bias).\n",
    "        W_torch = -W_torch\n",
    "    assert v.shape == W_torch.shape, (k, v.shape, W_torch.shape)\n",
    "    print(k, v.shape, W_torch.shape)\n",
    "    state_dict_selected[k] = W_torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Overwrite the converted entries in the full state dict (this mutates `state_dict`, which is\n",
    "# saved afterwards), then load it strictly so any missing or unexpected key raises.\n",
    "for key, tensor_value in state_dict_selected.items():\n",
    "    state_dict[key] = tensor_value\n",
    "model.load_state_dict(state_dict, strict=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the full converted state dict to disk.\n",
    "converted_ckpt_path = 'from_tf.pth'\n",
    "torch.save(state_dict, converted_ckpt_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "\n",
    "IMAGE_PATH = '../data/frame_1.jpg'\n",
    "INPUT_SIZE = 512  # the converted model is M-LSD_512\n",
    "\n",
    "img = cv2.imread(IMAGE_PATH)\n",
    "# cv2.imread returns None (it does not raise) when the file is missing or unreadable.\n",
    "assert img is not None, 'could not read image: ' + IMAGE_PATH\n",
    "img = cv2.resize(img, (INPUT_SIZE, INPUT_SIZE))\n",
    "# The network takes 4 input channels; BGR2RGBA appends an alpha channel filled with 255.\n",
    "img = cv2.cvtColor(img, cv2.COLOR_BGR2RGBA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Scale pixels from [0, 255] to [-1, 1] and reorder HWC -> NCHW. `img` itself is left\n",
    "# untouched so this cell can be re-run without normalizing twice.\n",
    "img_norm = (img / 127.5) - 1.0\n",
    "inputs = torch.from_numpy(img_norm.transpose(2, 0, 1)).unsqueeze(0).float()\n",
    "\n",
    "# eval() makes BatchNorm use the loaded running statistics rather than batch statistics.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    y = model(inputs)\n",
    "\n",
    "# NOTE(review): assumes y[0][0] is batch item 0, channel 0 = center-point logits; confirm.\n",
    "center_map = torch.sigmoid(y[0][0])\n",
    "center_map_np = center_map.detach().numpy()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQYAAAD8CAYAAACVSwr3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAeoElEQVR4nO2da6wkV3HHf3Ufu7Z3/Vq/WHaNbZJFwSjErG5sS04IUQIBK8qSD0RGUbAipOUDKCAlSgz5AJ8iEvFIEBGSIyzsiOA4AcQmckKMRYKUBPDacdYvjNfG4PUu+7Bjrx/rva/Kh+rjPrenZ+7M3O7p0z31k0Yzt2/P9JmePv+uU6eqjqgqjuM4MTNNN8BxnPRwYXAcpwcXBsdxenBhcBynBxcGx3F6cGFwHKeH2oRBRN4pIo+KyEERuamu4ziOUz1SRxyDiMwCPwTeDhwC7gHeq6oPV34wx3Eqpy6L4WrgoKo+oaqLwO3AnpqO5ThOxczV9Lk7gKeivw8B1/Tb+UIRvbymhkwEAc4ETgMrDbfFccqYgXtXOaGqFw2ze13CICXb1oxZRGQvsBfgdcB+Ke7RAgSzuTYDVwKPA8/T+z3a9r2cTiLw42H3rWsocQi4NPp7J3A43kFVb1bVBVVdGErCUkOyxyxmLcxRLgAuCk4KrI62e13CcA+wS0SuEJFNwA3AvoHvaFsHCu1dBV4A/id71pJ9HKdl1DKUUNVlEfkQ8E3snnqLqj5Ux7EaRcmVOAyFXAycDlCXjwFVvRO4s67PTwLt8+w4Lac2YZgqFLcWnE6RVkh02VxGyoT2uig4HSMdiyFM/a3Snk7WlnY6zoikZTG0hTBV6TgdJR2LAdphLRRFIfX2Os4YuMUwLu5XcDpMWhZDivQbMrgoOB0mPYshpbF7v4yPEcNLHadtpGMxKHnuQTDTm+yAQTLj5C4XBGdKSEcYAvPkDr5TNNsZY0ejDx2cKSItYZgDdgOvYKJwEFhsqC2xc9FFwZky0hGGUOzkz4AnscJwnweWqKdjSuF1mQi4IDhTSjrORwHOBd4E/CJwLXBGTceZwSRxM7Ale2zC/BszrBUKx5lC0rIYdgAXYB12E2ZB1HGczZgIXQBsB04Az2L1FE5hVoqXaHOmmHSEYRV4BEvUfgg4ALyI3cGr6qSCOTcvBN6K+TOuzLb/F/Af2XGXKjqesz7BweszPkmRjjCACcHHgeewu/fzVH/BhKHCy9i3vyB7ns+O9zJuLUwSjyBNknSEQbEOeYD8Ylmh2osm+BfAxGcVG1KEs/ACbi04DikJQyDumFXfSYLJ+jJWmvY+4LWYSHwZq6HrJq3jJCgMdZuVq1icxLPAA5iD8zHgR8Byzcd2nJaQnjDUzSo2RFnCrIZvZs8+hHCcV5kuYQj1HlbJpyTd8eU4PaQrDHUUQ/Gqzo4zFGkKQ1Wp12UJUC4GjrMuaQpDYFyroSgsHuLsOCORTq5ETFWmvhQejuMMRZrCAC4KjtMgaQ8lNiIOXk/BccYmbWHYKC4IjjMW3RIGX+/BcSohXR/DRnBRcJwNkb4weFquM42E8gANHr4dzGA1EwadMBcRpyuE5RQaoj3CMAdsA86ia54Rx+ml4ZtcO4QhVJDeDVwGnEdbWu44rWRD914ReRKre7QCLKvqgohsA/4euBwrBP87qvp/w38oa4OSwvMclhF5Gk+RdqaHhhY8quK++6uqepWqLmR/3wTcraq7gLuzv4cnLuu+DdgJnIOVdj8BnMQKrbgvwZkWGojgrcMg3wPcmr2+FXj30O8M6zpswoYLlwK/gAnEEvA0VrB1ERcGZzqYwfpF6BsTPOxGUODfROReEdmbbbtEVY8AZM8Xl71RRPaKyH4R2X8crPbiXkwU5rJPnsGcjWAVpJ/BhhJelzENBPut3N9TD5swy/kSrJr5udjM3ATYqH//OlU9LCIXA3eJyA+GfaOq3gzcDLBwpih/
DPwu8HXymoyvYIVaT2BeDBeEdBDswt2BCcMzmHgv49ZcVWwGzsZE4SWsiPGLkzn0hoRBVQ9nz8dE5OvA1cBREdmuqkdEZDtwbN0P2oS5KsNUZOxkfBm3ElIjLNxzEfBL2B3tYay47lF8qFcV52D94ueB41hPOoWJb839YWwjUES2iMjZ4TXwDuBBYB9wY7bbjcA3hmqFYFaBkNdjXMasBl8AJi1mMB/QNcCvA78J3IBdATsx0fBU941zMbaG67uB3wLejk3Xz1L7+d2IxXAJ8HURCZ/zd6r6ryJyD3CHiLwf+AnwnnU/6SXgL7El6n5Kbo6G6C+/+6SFYneup4CD2G8UHMN1rU4+jTyLDR3OwtZY3Qo8nj2C5VATYwuDqj6BzRkUtz8D/NpIH7YEfBe4n94YBb/I0kMxS+4J4D+BI5ip+yQW1RKqcTsb41lMfE9jTshgpX0bG7LVWOU8neDiU9nDSR/FBPwEFqlyd7PN6SwvYIst7wJej83cHcAcvTX7cdIRBsdx1qKYOOzHhtg7sFXTJjBcc2FwnJRZwZZPPIF57J5hIgsluTA4Tsoo5r95HvM5LDKRqfv0hcFnJZxpJ0zfLzGaY3cDfceDWR0ndcJ6qyGeZ9jqThu4oaZvMTj140V00yeIw4Qs6PSFwS/UeoltRj/XaTPB3yf9oYSH1tbLDPlV4MLgZKQvDE59CBPP83cqYAJFW/ySmHaCU8uthfQJv9EEKjql72MAn7KsiqKTMYQ2O+1hlTy7ssZ4hnYIg7NxhF5fgottOwmzEzXiQ4lpIfgSGqo67FRIvOZETQKRvjD4BVwN85h9WOedpt+412eWqiWuVVITPpSYBkJ1rBAkM2mxdXGvlgksX5e+xRDwu87GWCEvqFtVRy3+Ji4Ak6PmoUT6FkPRPPWLb3SUautmxvEPxcKk/vtMhnCeg9+o4vOevjA4g5HC63CR1LkoavjsUabL3OlZHy4MzhriWYawgtc8Vo+x7hLugxJ6+omAx6NUS43in74w1DyWai3BOpjDBOF84E3AYawoa92BSxPwjDvN0R7no7OW0DHngDcCfwB8GtiNrWA0qTtz2XHcKpg8/QR6TOFuhzC41dCfeWz9r1/GqgifQ5q5D3X6PKaZqR1KBNxsXUsIcT4DOzePYzUB/wk4SfOdsJ8l4X6GaonPZYXntj3CAH5BFZnFnIzfBO7BFiX5CWkv6eezE9VTw02zXcLg5ITsulPYmgM/JQ9iSqXTuXUweYrnfMzz3y5h8LtNTogjWCQXidTG8e6YnBwV9432CINfUL2kHnFYZjHEQVgTWB9halDM5zR1wgD5WGoaTdSyIiupU4zEBJs1uSDbfhxb6bzGVZunhoqFtl3CAO3oEFUT54u06fvPYNOpSj7keSdwPXbl3Q58B1tlqU3fK2Wm0mKYVkaxklKypmaAc4GzsDUXBbgC+IVs2xHgQdKYXu0CFZ5DF4Y2UGaSl+0TjzFT6GgzwE5gAYuxOAkcy/63BbgM2Ab8uJHWTRcj3jBcGNrAetmSQRTmyKcsU2AWW7r9Oiyf48fAt7B8js1YO7dibU+lzQ4wREi0iNwiIsdE5MFo2zYRuUtEHsuez8+2i4h8TkQOisgBEdldZ+OnhlUGF1gJY/nNDL+u4SSYBS7CLIM3Y0leh4D7gMeAo+QZoU5SDGMxfAn4PHBbtO0m4G5V/aSI3JT9/SfAu4Bd2eMa4AvZszMs4xalCTENxcIpTfIycBfwIpb9+aPs8bns782Y49GpnxGHlusKg6p+R0QuL2zeA7wte30r8O+YMOwBblNVBb4rIueJyHZVPTJas6aU4uzDsOPC2JpIRRTAhgfHgP/Cpilfzra9AJzGLIVQbs5JinF9DJeEzq6qR0Tk4mz7DuCpaL9D2TYXhmEozj4Mq/KpxDUUp1QVE4CjmOMxRGguYYKwmO3nwlAfY1qgVTsfy0a3pc0Rkb3AXoDXbeRIKXSIqog7eNu+Vz+/RhCCODIv+ExWSGt6
1XmVcesxHBWR7QDZc5iEOgRcGu23E/NB96CqN6vqgqouXDTq0SewqGcj1F2rsSlWMd/HMmvrRQZxcIuhPsasZTKuMOwDbsxe3wh8I9r+vmx24lrg+Vr9C10ViLYxjLkaL55bfBQ/w6mO0EdmR3vbukMJEfkK5mi8UEQOAR8HPgncISLvxyoAvCfb/U4s4PUg5mr6/dGaMyTTcjG1Zbg0ii+k+Hf4jm35rm0j1AQ9h5E8fWITCM2yIKL7R31TbC10zRQtzk40/xPVRwjOCkFOXfstm2YTFn5+Ccij3KuqC8O8rR01H8voameRkkeT7ZjEZzf9XbuMYoJ7erS3eUh0ShQ7SFmJ9kkK4jA5GuN+Zr/tPktRLYrNCo0YSNZeiyFmWu40k+4wIf9ilmrPcdmyg00stjsNhBmhU6O9rf0WQ5fuMIO+RxOWwmz2qHox3JhYFLruT2mCcH5HLIbTDWGYJR9LtZ0yz32dlAlr+DvkXVRpLYTfqRjhWUOlY4cpKQZbJHi0wwXVBWGAyd81+1ld4U5eR4cdlELuVkP1VJ1ElTzz5OKwhE93jcugsuOT6KhlfgenGsaoB9luYYgLlIS5cBeGapm0b6OJ404DI57Pds9KxFWji8MKZzhSiZXw3zAp2m0xQD4OniH3oLfd1zCpO2cKcQNxzEbbf7cO0X6LIQwd5jB/w4jJIg7N3aGLaeY+ZZkM7bcYwnDiTPI7zmL/3VtBEx2jKcvBRSBJ2i0MscWwmXxIcbKxFrWHsmpLTQ8rnGRo91ACcp/CPJZFdjbuvCpLSuqXqNTmqlFObbTbYoiZx4YT037nK9Y36Me0nh9nKNpvMUAey3AB8HpsWDGNxDkOc+QJUOFXdseeMyTdEIbga5jBlj6bxgVMgoUQxwG4l98Zk/YLQxCF01jSz3z2mHY/Q7/aio4zBO33MSh5vnmYptzUXHNqYZR6iGWxAY4zIu23GMAshleyxzLmhOzGNxs+ZLksUKj4P8cZknZ3n9ABVjBReCl7Pou2f7OcUTt1mSBM+7DKGZludJ9VTBSewZa+eQ22vHpXwqNHqZ5UdY1GZyrphjCAWQ1LmMVwNjac6IowjIOLg7MB2u98DIS6dkuYtdB1YZiW2gW+EE0jdMNiiNdDXMZiGbbSvdmJmPXqFnRhwRoBziCv0uVMjO6c7uCEXMEiH88jvbyJUDOirja1WQSKCCYIu4FdmEPZmRjdEIbQIZYxJyTYhbSVdIRBsIHbZqo766l8tzoI1sJu4E2YyDsTo/0+hnCHDNGPx7G06zlMGOZIoz5DKFu2KXt9ivXrUw4aXxenIruWIanYsshfxc7Vi802Z9povzAEYqthGftmZ2N36LA+QpOEpcLiOpXDvGfQ/7qQRdpP/MJveaTP/51a6cZQIhAuptPkwnAGaXzL2Dk6ijis95ltJQytNjE4t6VLfpMWkUKXqZYVzAQFm7IMpnsKxBf5RqshFztM05WeRzl+GFJtJRdvJym6Jwyr2Jh0GRtGpFYgNnToqsqkpyAO4bsMc/ww27AVOCd7nsY0+cTpjo8B8inLZ7A70gXAucBzWERkKsxj7SumRY/rM2jK1I4X/Annvl/4dlxZahWz6l5h7fDKSYZuCUMgZFnOYsOJM0jHUSfA+VicxRHgedaup5BKO4eh6Dcp8wcUS83FDuL4c5ykWHcoISK3iMgxEXkw2vYJEXlaRO7PHtdH//uoiBwUkUdF5DfqanhfFBOGMEV5JhbTkIqfYQYL2LkGuIxctNpKEId+olBcXcoLyLSCYSyGLwGfB24rbP+sqn4q3iAiVwI3YCEprwW+JSJvUNXJrDEUzPGQTLWImewhDbvpKUuwdlwDvAVzvJ0EniRvWxs7Sr82x45JF4FWsa7FoKrfAZ4d8vP2ALer6mlV/RFwELh6A+0bHcUE4TQmEHNY7kRKbtZLsaK1b8Gsh+AcbdpXEB5V49ZB69jIZfAhETmQDTXOz7btAJ6K9jmUbetBRPaKyH4R2X98
A40oZRmzGJawbxhmJ1JAgROYX2ELVjuiSU9PqCq9OXtU2ZY4uS0Fa80ZmnGF4QvAzwBXYS60T2fby0bLpfcJVb1ZVRdUdeGiMRvRl0VMGE5hF/tFpBNrvwx8Gfgr4B+AH2bbR/UzjBM/UMYM+ZocZzDe1K5Q3p7Y/+C0irHuD6p6NLwWkb8B/jn78xBmKAd2AofHbt24rJL7GeaxCz6VtSYUOyPLWKc8TX2rPA87vg+l8ULY9jjHYZ1jOK1iLItBRLZHf/42EGYs9gE3iMhmEbkCG0F/f2NNHIMgDC9h1kMYTqTCKeCnwFHMe7M8ePdX6XdXLiPEF8wx+FeORXRckfLCs51jXYtBRL4CvA24UEQOAR8H3iYiV2GXwZPABwBU9SERuQN4GLvcPzixGYlAMF+DMITU65SKtqxgnTC+o49L2ftDHkLIQRjU4cuSl8Zpk/sQOoWoNi/xCyK6v8oPnMPGzDuA7diA5mHg3ioPskHK7vzr/RTFRWr7vVfInYmKWSjrDREGeYfKFsJ1WofAvaq6MMy+3Yx8DOG5i9ljlfzumcpFvVErIRaWopMvtkRWhjxWsBQoeQ5Vp0LUotN5uikMIUz3BSy46TR5+vVkBzaDKXbGfh047phBAELMQXj/KmunBZeyRz/BgP7HC5mPPwv8HGZ1HcaKpgwrNE6r6aYwwNqq0cvk8/UpCUNg0LRjyEacI+/skDsXN2HDpkUsMSmEgscxBMVjrCdCO4FfBH4TeF32GfuBfxzta42ND10ap7vCALkwLGEX/Dx5FaVU6DeOjwnL2q+Qd+ogdGdgWaQh1TzOVCxzCA4ShfCZPwf8CvBm8rDtWSaTBTnsrItTK90Vhjjzb5VcGFIbTgQGdYBF7HuEtOawf/AhnCIXhjnyzhXeN8x0YpxC/RLwOJayfiY2tfot6hXVol8DPIy6QbotDCvkHSrUQEgxk3G92o5hKBTXbQj+hGWsUOoi9j1nyX0S4X3rESyFEPNwGPhv4InsGMexrJc6pySDv2Q9/4czEborDJB3qiXy6MdgEqfGeuJQvOsH4TtNPvOyGm0P3734vvjOHFdemiW3pp7OHkK+Vkdd52yUoC1nYnRbGMKU5cvYWHkbdnc93WSjKmIWK40m2PebxwrAzGAVq0J5ei28ZxN2Ls7LHoexoUPIRo2HK3VRFKUgQJ6FmQzdFYb4rhqm7bZgY+bnaP/FN4N9H8jrTrwu+/tFyjvYDDYN+QYsWP1c4E5sWndSolAUBFgrCG3/XTpCSlUKqqNoPgdvfhCGLhASn4JDMFgQYahU7ODhLh0shlDV6iS547JOUYhrPhQDs2JrwUmC7loMgTDWXsX8DCnVf9wIK9idfpa87sSzWGJW2VApdMZFrM7k09l+z1H/bENZWraXeEuabgtDnFC1iFkLYTqv7YQy+YEl4AeYv6GskwWBPAk8gtWBOE291bPjmo/FtrggJE13hSFccCvYmPsF8hWw50kzlmFUVguvi+s7FgUwzDDU7XwNghBXzYqFoFjf0sUhObrpY4gJd8pQTn4Lg5dE6xpVpHaPeqwZ8niK4sI6LgatoLvCUMwLCME+Z1HtUvSpE9+h6yQWhSAMZaLgw4dW0O3uEWcehuSjMzBfQ0rL1tXFpDphmaUQn99B7XCRSJJuC0OgWDU63M2caigKQhzeHC9GUxbB6SRJ97tHGEaE0OFQKn0aLIa6CaHUIc8iiEK4qlwUWkt3ZyUCxQzE12Chw8/Rf2rPKSeORYitg+KwIU7wKi5fp6yNZ4jf5yRD9y2GQKh6tAUThlTWmWgTsWMxFoW4o4dIxhBeXczXcGdkK5geYQichcUzbGMav/14xI7FUJI+Tu+G3DIrigLR/z2wqTV0dyhRnC8PdzKwfIG2rDLd9GIuxWnIsrDmIAJBENyf0Hq6Kwz9TNY5zGpIbaHbMuIOOYmyasVjx+nRsSjEohsshXjY4EFMraebwlC0BGKH2VbsAn4Fm51YJH2COExqHchi0lM8
/VgUheKUpNMJuikMcaUiyAOcTmDZh8EJeQ42M5Fq3kQ8bp8UsRAEBvkKKPzP6QTdFAbIxSE2d08Cz2D+hTCsSLU4bBMMymkoG0K4EHSW1EfZGyM2dUNhkzB0KHrVp52ihVDs/IMshUHrYjitpNvCEKNY9ONLmECssDZKb5op69T9nIhlloTTObo7lIgJF/EStkbCEuaEDEFP006/AKSytR6K//PhRCfpvjDEd76TwP1YPYZ5hlsFeloZ5G9wMeg83ReGmFXy1ZnCAi2TqFXQNvrlMbggTA3TJQxF3Gm2lkGJTS4KU8X0CUPsNHNhyBm0EpSLwtSxrutNRC4VkW+LyCMi8pCIfDjbvk1E7hKRx7Ln87PtIiKfE5GDInJARHbX/SVGxufgcwaJo5+nqWUYn/wy8Ieq+kbgWuCDInIlcBNwt6ruAu7O/gZ4F7bO0S5gL/CFyls9iGEtAL/o11KWHOVMLesKg6oeUdX7stcvYKsS7AD2ALdmu90KvDt7vQe4TY3vAueJyPbKW963wUP8PyxbF1bCnkbKFpKdVC6GkzwjzeKLyOXAW4DvAZeo6hEw8QAuznbbATwVve1Qti093Mfg4c1OKUMLg4hsBb4KfERVTw7atWRbz2UnIntFZL+I7D8+bCOqZho7Q3HI4DglDCUMIjKPicKXVfVr2eajYYiQPR/Lth8CLo3evhNbbH0Nqnqzqi6o6sJF47beGR23EJwhGGZWQoAvAo+o6meif+0Dbsxe3wh8I9r+vmx24lrg+TDkcBynHQwTx3Ad8HvAAyJyf7btY8AngTtE5P3AT4D3ZP+7E7geOIhVO/j9SltcFcWaDY7jvIqoNm9TLojo/o1+yDjrM4aiJO6Nd6YAgXtVdWGYfac3t7DfEu2O40xhSHQgrh7tOM4a/H7pvgbH6aE7wuA+AsepjO4Ig+M4ldF+H8NGovhCoI8PJRxnDe23GLxTO07ltF8YHMepHBcGcMel4xRotzD4MMJxaqE7zkfwO7/jVES7LYaApxE7TqW0Wxh8KOE4tdDuoYRbCY5TC+22GBzHqYV2C4MPJRynFtotDI7j1EI3fAzua3CcSnGLwXGcHrphMTiOUyluMTiO04MLg+M4PbgwOI7TgwuD4zg9uDA4jtODC4PjOD24MDiO04MLg+M4PbgwOI7TgwuD4zg9uDA4jtODC4PjOD24MDiO00O7syurxkvROw4whMUgIpeKyLdF5BEReUhEPpxt/4SIPC0i92eP66P3fFREDorIoyLyG3V+gVpwUXCmnGEshmXgD1X1PhE5G7hXRO7K/vdZVf1UvLOIXAncALwJeC3wLRF5g6quVNnwWnBBcBxgCItBVY+o6n3Z6xeAR4AdA96yB7hdVU+r6o+Ag8DVVTTWcZzJMJLzUUQuB94CfC/b9CEROSAit4jI+dm2HcBT0dsOUSIkIrJXRPaLyP7jIzfbcZw6GVoYRGQr8FXgI6p6EvgC8DPAVcAR4NNh15K39xjpqnqzqi6o6sJFIzfbcZw6GUoYRGQeE4Uvq+rXAFT1qKquqOoq8Dfkw4VDwKXR23cCh6trsuM4IzPiGizDzEoI8EXgEVX9TLR9e7TbbwMPZq/3ATeIyGYRuQLYBXx/tGY5jlMZwsgRS8PMSlwH/B7wgIjcn237GPBeEbkKGyY8CXwAQFUfEpE7gIexGY0PtmJGwnG6SBCFTcCpEd6m2vwcnYgcB14CTjTdliG4kHa0E9rTVm9n9ZS19TJVHcqll4QwAIjIflVdaLod69GWdkJ72urtrJ6NttVzJRzH6cGFwXGcHlIShpubbsCQtKWd0J62ejurZ0NtTcbH4DhOOqRkMTiOkwiNC4OIvDNLzz4oIjc13Z4iIvKkiDyQpZbvz7ZtE5G7ROSx7Pn89T6nhnbdIiLHROTBaFtpu8T4XHaOD4jI7gTamlza/oASA0md14mUQlDVxh7ALPA48HosBON/gSubbFNJG58ELixs+wvgpuz1TcCfN9CutwK7
gQfXaxdwPfAvWLjLtcD3EmjrJ4A/Ktn3yuw62AxckV0fsxNq53Zgd/b6bOCHWXuSOq8D2lnZOW3aYrgaOKiqT6jqInA7lradOnuAW7PXtwLvnnQDVPU7wLOFzf3atQe4TY3vAucVQtprpU9b+9FY2r72LzGQ1Hkd0M5+jHxOmxaGoVK0G0aBfxORe0Vkb7btElU9AvYjARc31rq19GtXqud57LT9uimUGEj2vFZZCiGmaWEYKkW7Ya5T1d3Au4APishbm27QGKR4njeUtl8nJSUG+u5asm1iba26FEJM08KQfIq2qh7Ono8BX8dMsKPBZMyejzXXwjX0a1dy51kTTdsvKzFAgue17lIITQvDPcAuEblCRDZhtSL3NdymVxGRLVmdS0RkC/AOLL18H3BjttuNwDeaaWEP/dq1D3hf5kW/Fng+mMZNkWLafr8SAyR2Xvu1s9JzOgkv6joe1usxr+rjwJ823Z5C216PeXP/F3gotA+4ALgbeCx73tZA276CmYtL2B3h/f3ahZmSf52d4weAhQTa+rdZWw5kF+72aP8/zdr6KPCuCbbzlzAT+wBwf/a4PrXzOqCdlZ1Tj3x0HKeHpocSjuMkiAuD4zg9uDA4jtODC4PjOD24MDiO04MLg+M4PbgwOI7TgwuD4zg9/D9y+yGWAosTeQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Show the sigmoid center-point heat map with a colorbar so values in [0, 1] can be read off.\n",
    "fig, ax = plt.subplots()\n",
    "im = ax.imshow(center_map_np, cmap='autumn', interpolation='nearest')\n",
    "ax.set_title('Center map (sigmoid)')\n",
    "fig.colorbar(im, ax=ax)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
